{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "V100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Training a Simple GAN with PyTorch\n",
        "- GAN for Generating simple MNIST Images\n",
        "- Good to have a GPU for faster training"
      ],
      "metadata": {
        "id": "r2dfGLNrTmHv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!rm -r track\n",
        "!mkdir track"
      ],
      "metadata": {
        "id": "k0IOFO4DUw06"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "4YxqP7LbiUrv"
      },
      "outputs": [],
      "source": [
        "# Torch Imports\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "\n",
        "# Torchvision Imports\n",
        "from torchvision import datasets, transforms\n",
        "from torchvision.utils import save_image\n",
        "\n",
        "# Device configuration; Good to have a GPU for GAN Training\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "device"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "K5f7nyCDieqA",
        "outputId": "ea1512ee-2144-46f5-8f56-064587527203"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "device(type='cuda')"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Generator Network\n",
        "- 4 Fully Connected Layers\n",
        "- Output Dimension will be same as 28*28"
      ],
      "metadata": {
        "id": "SbdcNZOHTUgM"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Generator(nn.Module):\n",
        "    def __init__(self, g_input_dim, g_output_dim):\n",
        "        super(Generator, self).__init__()\n",
        "        self.fc1 = nn.Linear(g_input_dim, 256)\n",
        "        self.fc2 = nn.Linear(256, 512)\n",
        "        self.fc3 = nn.Linear(512, 1024)\n",
        "        self.fc4 = nn.Linear(1024, g_output_dim)\n",
        "\n",
        "\n",
        "    # forward method\n",
        "    def forward(self, x):\n",
        "        x = F.leaky_relu(self.fc1(x), negative_slope=0.2)\n",
        "        x = F.leaky_relu(self.fc2(x), negative_slope=0.2)\n",
        "        x = F.leaky_relu(self.fc3(x), negative_slope=0.2)\n",
        "        return torch.tanh(self.fc4(x))\n"
      ],
      "metadata": {
        "id": "lYTu38GFik4Z"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Discriminator Network\n",
        "- 4 Fully Connected Layers\n",
        "- Output is a Single Neuron, 1=Real Image, 0=Fake Image"
      ],
      "metadata": {
        "id": "G-JIxeU0TcyI"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Discriminator(nn.Module):\n",
        "    def __init__(self, d_input_dim):\n",
        "        super(Discriminator, self).__init__()\n",
        "        self.fc1 = nn.Linear(d_input_dim, 1024)\n",
        "        self.fc2 = nn.Linear(1024, 512)\n",
        "        self.fc3 = nn.Linear(512, 256)\n",
        "        self.fc4 = nn.Linear(256, 1)\n",
        "\n",
        "    # forward method\n",
        "    def forward(self, x):\n",
        "        x = F.leaky_relu(self.fc1(x), negative_slope=0.2)\n",
        "        x = F.leaky_relu(self.fc2(x), negative_slope=0.2)\n",
        "        x = F.leaky_relu(self.fc3(x), negative_slope=0.2)\n",
        "        return torch.sigmoid(self.fc4(x))\n"
      ],
      "metadata": {
        "id": "Is-ZNi1tio-_"
      },
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Instantiating Models\n",
        "- Using a random noise of 100 dimensions\n",
        "- Using BCE Loss for Training\n",
        "- Using Adam Optimizer for Training"
      ],
      "metadata": {
        "id": "nvYTEfyMTJKV"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Models\n",
        "z_dim = 100 # Dimension of the Input Noise to the Generator\n",
        "mnist_dim = 28 * 28 # As we have 28x28 pixel images in MNIST\n",
        "G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)\n",
        "D = Discriminator(mnist_dim).to(device)\n",
        "\n",
        "# loss\n",
        "criterion = nn.BCELoss()\n",
        "\n",
        "# optimizer\n",
        "lr = 0.0005\n",
        "G_optimizer = optim.Adam(G.parameters(), lr = lr)\n",
        "D_optimizer = optim.Adam(D.parameters(), lr = lr)"
      ],
      "metadata": {
        "id": "6oesO7aEnZu3"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Loading & Normalizing MNIST Dataset"
      ],
      "metadata": {
        "id": "MQD5--pkTCwi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Batch Size for Training\n",
        "batch_size = 100\n",
        "\n",
        "# Transforming Images to Tensor and then Normalizing the values to be between 0 and 1, with mean 0.5 and std 0.5\n",
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Normalize(mean=(0.5,), std=(0.5,))])\n",
        "\n",
        "# Loading the Training Dataset and applying Transformation\n",
        "train_dataset = datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)\n",
        "test_dataset = datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)\n",
        "\n",
        "# Data Loader which will be used as input to the models.\n",
        "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n",
        "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)\n"
      ],
      "metadata": {
        "id": "Mi4nn196i4bi"
      },
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Method for Training Discriminator"
      ],
      "metadata": {
        "id": "uG5014WuTANh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def train_discriminator(mnist_data_batch):\n",
        "    #===================Discriminator-Trainer===================#\n",
        "    D.zero_grad()\n",
        "\n",
        "    # train discriminator on real\n",
        "    x_real, y_real = mnist_data_batch.view(-1, mnist_dim), torch.full((batch_size, 1), 0.90)\n",
        "    x_real, y_real = x_real.to(device), y_real.to(device)\n",
        "\n",
        "    # Loss for Real MNIST Data\n",
        "    D_output = D(x_real)\n",
        "    D_real_loss = criterion(D_output, y_real)\n",
        "\n",
        "    # Generate Data from Generator Network for Training\n",
        "    z = torch.randn(batch_size, z_dim).to(device) # Random Noise as Input to G\n",
        "    x_fake, y_fake = G(z), torch.full((batch_size, 1), 0.1).to(device)\n",
        "\n",
        "    # Loss for Fake Data from Generator\n",
        "    D_output = D(x_fake)\n",
        "    D_fake_loss = criterion(D_output, y_fake)\n",
        "\n",
        "    # Updating only D's weights, so training on D on total loss\n",
        "    D_loss = D_real_loss + D_fake_loss\n",
        "    D_loss.backward()\n",
        "    D_optimizer.step()\n",
        "\n",
        "    return  D_loss.data.item()\n"
      ],
      "metadata": {
        "id": "RZkoCFbQi8Et"
      },
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Method for Training Generator"
      ],
      "metadata": {
        "id": "svYCG8BRS8Go"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def train_generator():\n",
        "    #=======================Generator-Trainer=======================#\n",
        "    G.zero_grad()\n",
        "\n",
        "    # Generate Random Noise to use as Input to the Generator\n",
        "    z = torch.randn(batch_size, z_dim).to(device)\n",
        "    # The final label in this case in 1(True), i.e the Discriminator Model\n",
        "    # Thinks these are real images therefore the losses for Generator\n",
        "    # Network should update the weights of G in such a way that it produces more\n",
        "    # Real looking images.\n",
        "    y = torch.full((batch_size, 1), 0.90).to(device)\n",
        "\n",
        "    # Generate Images\n",
        "    G_output = G(z)\n",
        "    # Get output from Discriminator for the Generated Images\n",
        "    D_output = D(G_output)\n",
        "    # Compute Loss\n",
        "    G_loss = criterion(D_output, y)\n",
        "    # G_loss = -D_output.mean()\n",
        "\n",
        "    # Updating only G's weights\n",
        "    G_loss.backward()\n",
        "    G_optimizer.step()\n",
        "\n",
        "    return G_loss.data.item()"
      ],
      "metadata": {
        "id": "5_ZOsm6zjAj5"
      },
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Main Training Loop"
      ],
      "metadata": {
        "id": "M0zmaw4MS6bP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Total number of epochs for Training\n",
        "n_epoch = 200\n",
        "\n",
        "losses = []\n",
        "\n",
        "for epoch in range(n_epoch):\n",
        "    D_losses, G_losses = [], []\n",
        "    # For Every Batch in the MNIST dataset\n",
        "    # We train the Discriminator First, then the Generator\n",
        "    for batch_idx, (mnist_input_data, _) in enumerate(train_loader):\n",
        "        D_losses.append(train_discriminator(mnist_input_data))\n",
        "\n",
        "        # Weight Clamping\n",
        "        for p in D.parameters():\n",
        "          p.data.clamp_(-0.01, 0.01)\n",
        "\n",
        "        if batch_idx % 2 == 0:\n",
        "          G_losses.append(train_generator())\n",
        "\n",
        "    if epoch % 10 == 0:\n",
        "        test_z = torch.randn(8, z_dim).to(device)\n",
        "        generated = G(test_z)\n",
        "        save_image(generated.view(generated.size(0), 1, 28, 28), f'./track/sample_images_{epoch}' + '.png')\n",
        "\n",
        "    # Logging the Loss values from Discriminator and Generator\n",
        "    loss_d = torch.mean(torch.FloatTensor(D_losses))\n",
        "    loss_g = torch.mean(torch.FloatTensor(G_losses))\n",
        "    losses.append((loss_d, loss_g))\n",
        "    print(f'[{epoch}/{n_epoch}]: loss_d: {loss_d}, loss_g: {loss_g}')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BNEl9XfflU21",
        "outputId": "99ffffa6-a0a4-4873-8836-3142845a6f9e"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[0/200]: loss_d: 0.8678494095802307, loss_g: 1.9354009628295898\n",
            "[1/200]: loss_d: 0.6945022344589233, loss_g: 2.118706464767456\n",
            "[2/200]: loss_d: 0.6817149519920349, loss_g: 2.1160173416137695\n",
            "[3/200]: loss_d: 0.6955651044845581, loss_g: 2.1698546409606934\n",
            "[4/200]: loss_d: 0.7377318143844604, loss_g: 2.09809947013855\n",
            "[5/200]: loss_d: 0.7426372170448303, loss_g: 1.96819007396698\n",
            "[6/200]: loss_d: 0.7336576581001282, loss_g: 1.99467933177948\n",
            "[7/200]: loss_d: 0.7046687006950378, loss_g: 2.0825297832489014\n",
            "[8/200]: loss_d: 0.7505753040313721, loss_g: 2.0209078788757324\n",
            "[9/200]: loss_d: 0.7307223677635193, loss_g: 2.046144485473633\n",
            "[10/200]: loss_d: 0.7518192529678345, loss_g: 1.891499400138855\n",
            "[11/200]: loss_d: 0.7489187717437744, loss_g: 1.8791155815124512\n",
            "[12/200]: loss_d: 0.7474170923233032, loss_g: 1.8137696981430054\n",
            "[13/200]: loss_d: 0.7766727805137634, loss_g: 1.793921947479248\n",
            "[14/200]: loss_d: 0.79848313331604, loss_g: 1.7094461917877197\n",
            "[15/200]: loss_d: 0.8160367608070374, loss_g: 1.6751149892807007\n",
            "[16/200]: loss_d: 0.8122260570526123, loss_g: 1.7098844051361084\n",
            "[17/200]: loss_d: 0.8105064034461975, loss_g: 1.6491848230361938\n",
            "[18/200]: loss_d: 0.8460466265678406, loss_g: 1.5804144144058228\n",
            "[19/200]: loss_d: 0.8607600927352905, loss_g: 1.5041625499725342\n",
            "[20/200]: loss_d: 0.879481315612793, loss_g: 1.5266501903533936\n",
            "[21/200]: loss_d: 0.8593348264694214, loss_g: 1.4807443618774414\n",
            "[22/200]: loss_d: 0.8635353446006775, loss_g: 1.4727977514266968\n",
            "[23/200]: loss_d: 0.8677745461463928, loss_g: 1.5209871530532837\n",
            "[24/200]: loss_d: 0.8734594583511353, loss_g: 1.4630192518234253\n",
            "[25/200]: loss_d: 0.8876548409461975, loss_g: 1.430591344833374\n",
            "[26/200]: loss_d: 0.8962651491165161, loss_g: 1.4056795835494995\n",
            "[27/200]: loss_d: 0.8826575875282288, loss_g: 1.4605122804641724\n",
            "[28/200]: loss_d: 0.8812916874885559, loss_g: 1.4460475444793701\n",
            "[29/200]: loss_d: 0.8905916213989258, loss_g: 1.4520634412765503\n",
            "[30/200]: loss_d: 0.9047392010688782, loss_g: 1.4209439754486084\n",
            "[31/200]: loss_d: 0.9060031175613403, loss_g: 1.4329609870910645\n",
            "[32/200]: loss_d: 0.9033512473106384, loss_g: 1.3785659074783325\n",
            "[33/200]: loss_d: 0.9315887689590454, loss_g: 1.371907353401184\n",
            "[34/200]: loss_d: 0.933129072189331, loss_g: 1.333664894104004\n",
            "[35/200]: loss_d: 0.9307510256767273, loss_g: 1.3389054536819458\n",
            "[36/200]: loss_d: 0.9433754682540894, loss_g: 1.3440548181533813\n",
            "[37/200]: loss_d: 0.9494260549545288, loss_g: 1.3095784187316895\n",
            "[38/200]: loss_d: 0.9628462791442871, loss_g: 1.2197757959365845\n",
            "[39/200]: loss_d: 0.9643799066543579, loss_g: 1.2667912244796753\n",
            "[40/200]: loss_d: 0.9652414917945862, loss_g: 1.2855380773544312\n",
            "[41/200]: loss_d: 0.9695826172828674, loss_g: 1.2260864973068237\n",
            "[42/200]: loss_d: 0.9721906781196594, loss_g: 1.2536559104919434\n",
            "[43/200]: loss_d: 0.9843592047691345, loss_g: 1.2571946382522583\n",
            "[44/200]: loss_d: 0.9978849291801453, loss_g: 1.1914997100830078\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Generating Images from our Trained Generator Network"
      ],
      "metadata": {
        "id": "-7judlTOS247"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "with torch.no_grad():\n",
        "    test_z = torch.randn(8, z_dim).to(device)\n",
        "    generated = G(test_z)\n",
        "\n",
        "    save_image(generated.view(generated.size(0), 1, 28, 28), './track/sample_images_200' + '.png')"
      ],
      "metadata": {
        "id": "gxVOotx_oxeS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt"
      ],
      "metadata": {
        "id": "fbhmXPVx-vgR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plt.imshow(generated.view(generated.size(0), 1, 28, 28).cpu()[2].permute(1,2,0))"
      ],
      "metadata": {
        "id": "WpZVI2Snuihl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Plotting Losses from the Training"
      ],
      "metadata": {
        "id": "OEPmb8tlSyU2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "plt.plot(range(1, len(losses)+1), [_[0] for _ in losses], label='discriminator_loss')\n",
        "plt.plot(range(1, len(losses)+1), [_[1] for _ in losses], label='generator_loss')\n",
        "plt.legend()\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "ylatKU9TSI1s"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "oICcnSSvSPUD"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}